Causal Random Forest

Published

November 15, 2025

Introduction

Background to causal forests based on:

  • Athey, Susan, Julie Tibshirani, and Stefan Wager. 2019. “Generalized Random Forests.” The Annals of Statistics 47 (2): 1148–78. https://doi.org/10.1214/18-AOS1709.
  • Wager, Stefan, and Susan Athey. 2018. “Estimation and Inference of Heterogeneous Treatment Effects Using Random Forests.” Journal of the American Statistical Association 113 (523): 1228–42.

We’ll use the grf guide example to demonstrate this: https://grf-labs.github.io/grf/articles/grf_guide.html

Background

In causal analysis, we aim to estimate the causal effect \(\tau\) of a treatment \(W\). If the data come from a randomized controlled trial, we can assume no confounders, and the effect is just:

\[ \tau = E[Y_i(1) - Y_i(0)] \]
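In the RCT case, this reduces to a simple difference in means. A quick simulated sketch (my own illustration, not from the grf guide):

```r
## Simulated RCT: with random assignment, a difference in means
## estimates the average treatment effect.
set.seed(1)
n <- 10000
W <- rbinom(n, 1, 0.5)            # randomized treatment
Y <- 1.5 * W + rnorm(n)           # true tau = 1.5
mean(Y[W == 1]) - mean(Y[W == 0]) # close to 1.5
```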

In observational studies, we have confounders (\(X_i\)), where we need to account for their effect on both \(Y_i\) and \(W_i\) for any individual \(i\). The effect can now be estimated from the following regression:

\[ Y_i = \tau W_i + \beta X_i + \epsilon_i \]

where \(\hat{\tau}\) is taken to be a good estimate of \(\tau\). This relies on the following assumptions:

  1. Conditional unconfoundedness, i.e. \(W_i\) is unconfounded given \(X_i\)
    • \(\{Y_i(1), Y_i(0)\} \perp W_i | X_i\)
  2. The error is random (conditional on \(W_i\) and \(X_i\))
    • \(E[\epsilon_i|X_i,W_i] = 0\)
  3. The confounders have a linear effect on \(Y_i\)
  4. The treatment effect is constant

We can’t do anything about assumptions 1 and 2 as they are necessary to identify the model. But assumptions 3 and 4 relate to the model used and can be questioned.

Non-linear effects

For assumption 3, we can relax this through a standard semi-parametric approach:

\[ Y_i = \tau W_i + f(X_i) + \epsilon_i \]

Here, the baseline outcome for individual \(i\) is some unknown function of \(X_i\), which can be complex. The treatment is still constant and shifts the baseline estimate by \(\tau\).

The question is then how to define \(f()\), given that it is unknown. To do so, we can take advantage of double machine learning, itself based on the Frisch-Waugh-Lovell theorem.

Frisch-Waugh-Lovell theorem

This basically states that, in a regression with a confounder, the coefficient of interest can instead be obtained from a regression of residuals on residuals. For causal analysis, we want to estimate \(Y_i = \tau W_i + \beta X_i + \epsilon_i\). \(\tau\) can be obtained in the following steps:

  1. Estimate the conditional mean equation, and residuals based on this:

\[ E(Y_i|X_i=x) = \hat{Y}_i \sim \beta_Y X_i; \\ Y_i'= Y_i - \hat{Y}_i \]

  2. Estimate the propensity equation, and residuals based on this:

\[ E(W_i|X_i=x) = \hat{W}_i \sim \beta_W X_i; \\ W_i'= W_i - \hat{W}_i \]

  3. Estimate the causal effect by regressing the residuals:

\[ Y_i' \sim \tau W_i' \]

In short, this works because the first two models remove all variation in \(Y\) and \(W\) related to the confounder. Any remaining variance explained by \(W'\) is then causal.
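The theorem can be checked directly in a few lines of base R (a toy simulation of my own, with a known \(\tau = 2\)): the residual-on-residual slope matches the coefficient on \(W\) from the full regression exactly.

```r
## FWL sanity check on simulated data: residual-on-residual regression
## reproduces the full-regression coefficient on W.
set.seed(1)
n <- 500
X <- rnorm(n)
W <- 0.5 * X + rnorm(n)        # treatment confounded by X
Y <- 2 * W + 3 * X + rnorm(n)  # true tau = 2

tau_full <- coef(lm(Y ~ W + X))["W"]
Y_res <- resid(lm(Y ~ X))      # step 1: residualize the outcome
W_res <- resid(lm(W ~ X))      # step 2: residualize the treatment
tau_fwl <- coef(lm(Y_res ~ W_res))[2]  # step 3: regress the residuals

all.equal(unname(tau_full), unname(tau_fwl))  # TRUE
```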

This workflow outlines the steps (from http://medium.com/@med.hmamouch99/double-machine-learning-for-causal-inference-a-practical-guide-5d85b77aa586)

Frisch-Waugh-Lovell workflow

Double-machine learning

Robinson (1988) first showed that this approach can be used with semi-parametric models. In this paper, the FWL equations were modified to include a semi-parametric function of \(X\), \(f(X)\):

  • The conditional mean of \(Y\): \(m(x) = E(Y_i|X_i=x) = f(x) + \tau e(x)\)
  • The propensity score: \(e(x) = E(W_i|X_i=x)\)

This can then be rewritten as:

\[ Y_i - m(x) = \tau (W_i - e(x)) + \epsilon_i \]
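Substituting the definition of \(m(x) = f(x) + \tau e(x)\) shows why centering removes the nuisance function:

```latex
\begin{aligned}
Y_i - m(X_i) &= \big(\tau W_i + f(X_i) + \epsilon_i\big) - \big(f(X_i) + \tau e(X_i)\big) \\
             &= \tau \big(W_i - e(X_i)\big) + \epsilon_i
\end{aligned}
```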

Robinson describes this as ‘centering’: plug-in estimates of \(m(x)\) and \(e(x)\) are obtained, \(Y_i\) and \(W_i\) are centered, then the residuals are regressed on each other.

This was further extended by Chernozhukov (Chernozhukov et al., 2018), to allow the use of any predictive model in these equations. Essentially, to get to the final stage model, all we need are:

good predictions of the outcome conditional on \(X_i\), and the treatment assignment conditional on \(X_i\) (from grf docs)

In standard residual-on-residual approaches, we assume that the estimates of \(m(x)\) and \(e(x)\) are obtained through parametric means (e.g. OLS regression). These are replaced by machine learning models in double machine learning (DML) approaches, including causal RFs. Note that these plug-in estimates are obtained using ‘cross-fitting’: the prediction of, for example, the outcome \(m(x)\) for individual \(i\) is made with a model trained on all observations except \(i\). This avoids the overfitting bias that arises when a flexible, regularized model predicts an observation it was trained on.
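A minimal two-fold cross-fitting sketch (my own base-R illustration, not grf’s internal implementation): each observation’s plug-in prediction comes from a model fit on the other fold.

```r
## Two-fold cross-fitting: out-of-fold predictions of the outcome.
set.seed(1)
n <- 200
X <- rnorm(n)
Y <- X^2 + rnorm(n)
fold <- sample(rep(1:2, length.out = n))   # random fold assignment

m_hat <- numeric(n)
for (k in 1:2) {
  train <- data.frame(X = X[fold != k], Y = Y[fold != k])
  fit <- lm(Y ~ poly(X, 2), data = train)  # stand-in for any ML model
  m_hat[fold == k] <- predict(fit, newdata = data.frame(X = X[fold == k]))
}
Y_res <- Y - m_hat  # outcome residuals, free of own-observation overfitting
```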

Non-constant treatment effects

In the original equation, \(\tau\) is assumed to be a constant factor across all individuals. To relax this, they use the idea of subgroups within the data set, each of which has its own regression, giving a value of \(\tau\) for each group (\(\tau\) here is still the coefficient of a linear model based on the centered outcome and treatment). The equation now becomes:

\[ Y_i = \tau(X_i) W_i + f(X_i) + \epsilon_i \]

where \(\tau (X_i)\) is the conditional average treatment effect for a given set of values of \(X_i\):

\[ E(Y_i(1) - Y_i(0)|X_i = x) \]

The next question is how to find these groups. We want to find subgroups where \(\tau\) can be assumed constant; in other words, we want to find a set of observations over which we can calculate the residual-on-residual regression. Note that this is still a linear model, where the slope coefficient gives \(\tau\) for that set of observations:

\[ \tau(x) = \text{lm}\left(Y_i - \hat{m}^{-i}(X_i) \sim W_i - \hat{e}^{-i}(X_i), \ \mbox{weights} = 1\{X_i \in \mathcal{N}(x)\}\right) \]

where \(\mathcal{N}(x)\) is a neighborhood of \(x\).

A better way of thinking about this may be that we are trying to build a random forest where the outcome is the slope of a regression line. The loss function prioritizes the biggest difference in slope at any split point.
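A single candidate split can be scored this way in a few lines (simulated data of my own, with a true effect of 0 below zero and 2 above; the centering here uses the known constant propensity of 0.5):

```r
## Score one candidate split at X1 = 0 by the difference in
## residual-on-residual slopes on either side.
set.seed(1)
n <- 1000
X1 <- rnorm(n)
W <- rbinom(n, 1, 0.5)
tau <- ifelse(X1 > 0, 2, 0)   # heterogeneous treatment effect
Y <- tau * W + rnorm(n)

Yr <- Y - mean(Y)             # crude centering; true e(x) = 0.5
Wr <- W - 0.5
left <- X1 <= 0
tau_L <- coef(lm(Yr[left] ~ Wr[left]))[2]    # slope ~ 0
tau_R <- coef(lm(Yr[!left] ~ Wr[!left]))[2]  # slope ~ 2
abs(tau_L - tau_R)            # the split score: roughly 2
```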

CF Algorithm

  1. Fit the first-stage models (nuisance and propensity) using any standard machine learning algorithm, with cross-fitting: each observation’s prediction comes from a model trained without it
  2. Use the first-stage models to estimate values for the outcome (\(m(x)\)) and treatment (\(e(x)\))
  3. Calculate outcome residuals (\(Y' = Y_i - m(X_i)\)) and treatment residuals (\(W' = W_i - e(X_i)\))
  4. Fit the second-stage model (the causal forest). For each tree:
    1. Bootstrap the data into in-bag (IB) and out-of-bag (OOB) sets
    2. Using the IB set, for each feature \(X_j\):
      1. Iterate across values of \(X_j\) to partition the data into two sets (\(L\) and \(R\))
      2. Test for imbalance (hyperparameter: a 25–75% split is the maximum imbalance by default); skip the split if the imbalance is greater than this
      3. Fit the following model in each of the two partitions:
        • Left partition (\(L\)): \(Y'_L = \tau_L W'_L + e\)
        • Right partition (\(R\)): \(Y'_R = \tau_R W'_R + e\)
      4. Calculate the difference in treatment effect \(\delta \tau = |\tau_L - \tau_R|\)
      5. Repeat for all \(j\) features to find the feature (and value) that maximizes \(\delta \tau\)
    3. Make new data sets for both IB and OOB based on this split
    4. Repeat from step 2; all subsequent steps work on \(\gt 1\) data subsets, so splits of \(X_j\) must be tested across all existing partitions

Example 1

Example of fitting a causal forest with a nonlinear relationship between X and \(\tau\):

library(grf)
library(ggplot2)

Create some data (X[,1] is a confounder that also drives the treatment effect; X[,2] (linear) and X[,3] (non-linear) shift the baseline outcome)

set.seed(42)
n <- 2000
p <- 10
X <- matrix(rnorm(n * p), n, p)
X_test <- matrix(0, 101, p)
X_test[, 1] <- seq(-2, 2, length.out = 101)

W <- rbinom(n, 1, 0.4 + 0.2 * (X[, 1] > 0))
Y <- pmax(X[, 1], 0) * W + X[, 2] + pmin(X[, 3], 0) + rnorm(n)

Plot X_1 and Y

plot_df = data.frame(X= X[,1], W = as.factor(W), Y = Y)
ggplot(plot_df, aes(x = X, y= Y, col = W)) +
  geom_point() +
  geom_smooth()
`geom_smooth()` using method = 'gam' and formula = 'y ~ s(x, bs = "cs")'

Fit causal forest

tau_forest <- causal_forest(X, Y, W)
tau_forest
GRF forest object of type causal_forest 
Number of trees: 2000 
Number of training samples: 2000 
Variable importance: 
    1     2     3     4     5     6     7     8     9    10 
0.704 0.037 0.034 0.042 0.033 0.030 0.028 0.025 0.032 0.036 

Predict and plot

tau_hat <- predict(tau_forest, X_test)
plot_df = data.frame(X = rep(X_test[,1], 2),
                     tau = c(pmax(0, X_test[, 1]), tau_hat$predictions),
                     label = rep(c("Truth","Pred"), each = nrow(X_test)))
ggplot(plot_df, aes(x = X, y = tau, col = label)) +
  geom_line() +
  theme_bw()

Example 1b

This uses the same data as before, but walks through a single split in a causal tree.

Step 1: nuisance model for treatment (the propensity model)

forest_W <- regression_forest(X, W, tune.parameters = "all")
W_hat <- predict(forest_W)$predictions

Step 2: nuisance model for outcome

forest_Y <- regression_forest(X, Y, tune.parameters = "all")
Y_hat <- predict(forest_Y)$predictions

Step 2b (optional): variable importance, which can inform variable selection for the forest

forest_Y_varimp <- variable_importance(forest_Y)
forest_Y_varimp
             [,1]
 [1,] 0.064545274
 [2,] 0.722531655
 [3,] 0.179488680
 [4,] 0.003975335
 [5,] 0.004832954
 [6,] 0.005150976
 [7,] 0.004071005
 [8,] 0.005604137
 [9,] 0.005134106
[10,] 0.004665877
tau_forest <- causal_forest(X, Y, W,
                            W.hat = W_hat, Y.hat = Y_hat,
                            tune.parameters = "all")

tau_hat <- predict(tau_forest, X_test)

plot(X_test[, 1], tau_hat$predictions, ylim = range(tau_hat$predictions, 0, 2), xlab = "x", ylab = "tau", type = "l")
lines(X_test[, 1], pmax(0, X_test[, 1]), col = 2, lty = 2)

library(animation)

breaks = seq(-2,2,by = 0.1)
nbreaks = length(breaks)
tau_df = data.frame(brks = breaks, 
                    tau1 = rep(NA, nbreaks),
                    tau2 = rep(NA, nbreaks),
                    dtau = rep(NA, nbreaks))
plot_df = data.frame(X = X[,1], 
                     W = as.factor(W), 
                     W_hat = W_hat,
                     Y = Y,
                     Y_hat = Y_hat)
plot_df$tau_hat <- predict(tau_forest)$predictions
Output at: test.gif
[1] TRUE

Pseudo-code for the full tree (for illustration, splitting only on \(X_1\)):

nsplit = 10
out_df = data.frame(split = 1:nsplit,
                    X = rep(NA, nsplit),
                    tau_lo = rep(NA, nsplit),
                    tau_hi = rep(NA, nsplit))

data_list = list(plot_df)
breaks = seq(-2,2,by = 0.1)
nbreaks = length(breaks)

for (i in 1:nsplit) {
  max_dtau = max_df = max_x = max_tau1 = max_tau2 = -9999
  for (j in 1:length(data_list)) {
    tmp_dl = data_list[[j]]
    print(paste(i,j))
    tau_df = data.frame(brks = breaks, 
                        tau1 = rep(NA, nbreaks),
                        tau2 = rep(NA, nbreaks),
                        dtau = rep(NA, nbreaks))
    for (k in 1:nbreaks) {
      
      dat1 = tmp_dl |>
        dplyr::filter(X < breaks[k])
      dat2 = tmp_dl |>
        dplyr::filter(X >= breaks[k])
      
      if (nrow(dat1) > 10 & nrow(dat2) > 10) {
        ## Residual-on-residual regression in each partition
        ## (W is stored as a factor in plot_df, so convert back to 0/1)
        mod1 = lm(I(Y - Y_hat) ~ I(as.numeric(as.character(W)) - W_hat), dat1)
        tau_df$tau1[k] = coef(mod1)[2]
        mod2 = lm(I(Y - Y_hat) ~ I(as.numeric(as.character(W)) - W_hat), dat2)
        tau_df$tau2[k] = coef(mod2)[2]
      
        tau_df$dtau[k] = abs(tau_df$tau1[k] - tau_df$tau2[k])
      }
    }
    
    rowID = which.max(tau_df$dtau)
    if (any(!is.na(tau_df$dtau))) {
      if (tau_df$dtau[rowID] > max_dtau) {
        max_tau1 = tau_df$tau1[rowID]
        max_tau2 = tau_df$tau2[rowID]
        max_dtau = tau_df$dtau[rowID]
        max_x = tau_df$brks[rowID]
        max_df = j
      }
      
    }
  }
  print(paste(max_dtau, max_x, max_df, max_tau1, max_tau2))
  out_df$X[i] = max_x
  out_df$tau_lo[i] = max_tau1
  out_df$tau_hi[i] = max_tau2
  
  new_dl = list()
  for (j in 1:length(data_list)) {
    if (j == max_df) {
      dat1 = data_list[[j]] |>
        dplyr::filter(X < max_x)
      dat2 = data_list[[j]] |>
        dplyr::filter(X >= max_x)
      new_dl = append(new_dl, list(dat1, dat2))
    } else {
      new_dl = append(new_dl, list(data_list[[j]]))
    }
  }
  
  ## Reassign data_list
  data_list = new_dl
}

Example 2

From grf docs

In this section, we walk through an example application of GRF. The data come from Bruhn et al. (2016), an RCT in Brazil in which high schools were randomly assigned a financial education program (in settings like this it is common to randomize at the school level to avoid student-level interference). The program increased student financial proficiency on average. Other outcomes are considered in the paper; we’ll focus on the financial proficiency score here. A processed copy of this data, containing student-level records for around 17,000 students, is stored on the GitHub repo. It includes basic student characteristics, as well as additional baseline survey responses we use as covariates (two of these are aggregated into an index by the authors to assess students’ ability to save and their financial autonomy).

library(grf)

data <- read.csv("./data/bruhn2016.csv")
Y <- data$outcome
W <- data$treatment
school <- data$school
X <- data[-(1:3)]

Around 30% of students have one or more missing covariates. The missingness pattern doesn’t seem to vary systematically between the treated and controls, so we’ll keep them in the analysis, since GRF supports splitting on X’s with missing values.

sum(!complete.cases(X)) / nrow(X)
[1] 0.2934852
t.test(W ~ !complete.cases(X))

    Welch Two Sample t-test

data:  W by !complete.cases(X)
t = -0.3923, df = 9490.1, p-value = 0.6948
alternative hypothesis: true difference in means between group FALSE and group TRUE is not equal to 0
95 percent confidence interval:
 -0.01963191  0.01308440
sample estimates:
mean in group FALSE  mean in group TRUE 
          0.5131730           0.5164467 

Fitting causal forest

cf <- causal_forest(X, Y, W, W.hat = 0.5, clusters = school)
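A natural next step is to summarize the fit with grf’s doubly robust average treatment effect estimate. Sketched here on simulated stand-in data (the Bruhn et al. CSV isn’t bundled with this post): with a known randomization probability you can pass `W.hat` directly, as above.

```r
## Sketch on simulated data: pass the known randomization probability as
## W.hat, then summarize with the doubly robust ATE (estimate + std error).
library(grf)
set.seed(1)
n <- 1000; p <- 5
X <- matrix(rnorm(n * p), n, p)
W <- rbinom(n, 1, 0.5)
Y <- 0.8 * W + X[, 1] + rnorm(n)             # true ATE = 0.8
cf_sim <- causal_forest(X, Y, W, W.hat = 0.5, num.trees = 500)
average_treatment_effect(cf_sim)              # estimate near 0.8
```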